Here is a program that implements variable reuse cleverly, so that the computational graph is traversed only once. It assumes the compiler cannot recognise common subexpressions; therefore any common subexpression the user wishes to reuse must be declared as a named variable, otherwise it will be recomputed.
#ifndef DYNTAPE_HPP
#define DYNTAPE_HPP

#include <list>

// NOTE(review): a using-declaration at header scope leaks std::list into
// every includer; it is kept because existing code names `list` unqualified.
using std::list;

namespace AD {

// Hands out small positive integer indices and recycles released ones.
class indices {
protected:
    list<unsigned int> liveidx;   // indices currently handed out
    list<unsigned int> reuseidx;  // released indices available for reuse
    unsigned int maxlive;         // largest index ever handed out
    // Non-public so that outside code cannot copy an index allocator.
    indices(const indices&);
    const indices& operator=(const indices&);
public:
    indices();
    ~indices();
    unsigned int nextindex();          // obtain a free index
    void delindex(unsigned int idx);   // release an index for later reuse
};

class active;

// Intermediate (unnamed) result of an expression.  Instances are kept in the
// static pool `tempvals`, identified by an index from `tmpidx`; the operators
// below return references into that pool and release the temporary operands
// they consume.
class atemp {
protected:
    double _val;        // value
    double _dot;        // derivative (tangent) carried alongside the value
    unsigned int _id;   // pool index, printed as t[_id]
    static list<atemp> tempvals;
    static indices tmpidx;
    atemp();
    static atemp& moretemps();   // allocate a fresh pooled temporary
public:
    atemp(const atemp&);
    ~atemp();
    static void deltemp(const atemp&);   // remove a temporary from the pool
    const atemp& operator=(const atemp&);
    bool operator==(const atemp&) const;  // compares pool indices only
    unsigned int id() const;
    double val() const;
    double dot() const;

    friend const atemp& operator+(const atemp&, const atemp&);
    friend const atemp& operator+(const active&, const active&);
    friend const atemp& operator+(const atemp&, const active&);
    friend const atemp& operator+(const active&, const atemp&);

    friend const atemp& operator-(const atemp&, const atemp&);
    friend const atemp& operator-(const active&, const active&);
    friend const atemp& operator-(const atemp&, const active&);
    friend const atemp& operator-(const active&, const atemp&);

    friend const atemp& operator*(const atemp&, const atemp&);
    friend const atemp& operator*(const active&, const active&);
    friend const atemp& operator*(const atemp&, const active&);
    friend const atemp& operator*(const active&, const atemp&);
    friend const atemp& operator*(double, const atemp&);
    friend const atemp& operator*(double, const active&);
    friend const atemp& operator*(const atemp&, double);
    friend const atemp& operator*(const active&, double);

    friend const atemp& operator/(double, const atemp&);
    friend const atemp& operator/(double, const active&);
    friend const atemp& operator/(const atemp&, double);
    friend const atemp& operator/(const active&, double);
    friend const atemp& operator/(const atemp&, const atemp&);
    friend const atemp& operator/(const active&, const active&);
    friend const atemp& operator/(const atemp&, const active&);
    friend const atemp& operator/(const active&, const atemp&);

    friend const atemp& sqrt(const atemp&);
    friend const atemp& sqrt(const active&);
    friend const atemp& square(const atemp&);
    friend const atemp& square(const active&);
    friend const atemp& sin(const atemp&);
    friend const atemp& cos(const atemp&);
    friend const atemp& sin(const active&);
    friend const atemp& cos(const active&);
};

// A named program variable.  It receives an index from `actidx` when first
// constructed or assigned (0 means "no index yet"), printed as v[_id].
class active {
protected:
    double _val;
    double _dot;
    unsigned int _id;
    static indices actidx;
public:
    active();
    active(const active&);
    active(double);
    active(const atemp&);
    ~active();
    const active& operator=(const active&);
    const active& operator=(const atemp&);
    const active& operator=(double);
    unsigned int id() const;
    double val() const;
    double dot() const;
    void set_dot(double);   // seed the derivative of this variable
};

// Namespace-scope declarations.  A name introduced only by a friend
// declaration inside `atemp` is found solely via argument-dependent lookup
// when `atemp` is an associated class of an argument, so the overloads that
// take only `active`/`double` arguments (e.g. active + active, sin(active))
// would otherwise be invisible to code that includes just this header.
const atemp& operator+(const atemp&, const atemp&);
const atemp& operator+(const active&, const active&);
const atemp& operator+(const atemp&, const active&);
const atemp& operator+(const active&, const atemp&);

const atemp& operator-(const atemp&, const atemp&);
const atemp& operator-(const active&, const active&);
const atemp& operator-(const atemp&, const active&);
const atemp& operator-(const active&, const atemp&);

const atemp& operator*(const atemp&, const atemp&);
const atemp& operator*(const active&, const active&);
const atemp& operator*(const atemp&, const active&);
const atemp& operator*(const active&, const atemp&);
const atemp& operator*(double, const atemp&);
const atemp& operator*(double, const active&);
const atemp& operator*(const atemp&, double);
const atemp& operator*(const active&, double);

const atemp& operator/(double, const atemp&);
const atemp& operator/(double, const active&);
const atemp& operator/(const atemp&, double);
const atemp& operator/(const active&, double);
const atemp& operator/(const atemp&, const atemp&);
const atemp& operator/(const active&, const active&);
const atemp& operator/(const atemp&, const active&);
const atemp& operator/(const active&, const atemp&);

const atemp& sqrt(const atemp&);
const atemp& sqrt(const active&);
const atemp& square(const atemp&);
const atemp& square(const active&);
const atemp& sin(const atemp&);
const atemp& cos(const atemp&);
const atemp& sin(const active&);
const atemp& cos(const active&);

} // namespace AD

#endif // DYNTAPE_HPP
#include <iostream> #include <fstream> #include <cmath> #include "dyntape.hpp" #include <algorithm> using std::ios; using std::cout; using std::ofstream; using std::find; namespace AD { ofstream opcodes; indices::indices() { liveidx.clear(); reuseidx.clear(); maxlive = 0; } indices::~indices() { liveidx.clear(); reuseidx.clear(); maxlive = 0; } // the next two must never be called indices::indices(const indices& i) { throw -1; } const indices& indices::operator=(const indices& i) { throw -1; } // how to get the next available index unsigned int indices::nextindex() { unsigned int idx; if (!reuseidx.empty()) { idx = reuseidx.front(); reuseidx.pop_front(); } else { idx = ++maxlive; } liveidx.push_back(idx); return idx; } // how to delete an index void indices::delindex(unsigned int idx) { liveidx.remove(idx); reuseidx.push_front(idx); if (reuseidx.size() >= 2) reuseidx.sort(); } // static vars must be redeclared list<atemp> atemp::tempvals; indices atemp::tmpidx; indices active::actidx; // here is the meat of the code atemp::atemp() { _val = 0; _dot = 0; _id = 0; } atemp& atemp::moretemps() { atemp t; t._id = tmpidx.nextindex(); tempvals.push_back(t); return tempvals.back(); } atemp::atemp(const atemp& t) { _val = t._val; _dot = t._dot; _id = t._id; } const atemp& atemp::operator=(const atemp& t) { _val = t._val; _dot = t._dot; _id = t._id; return t; } atemp::~atemp() { _val = 0; _id = 0; _dot = 0; } void atemp::deltemp(const atemp& t) { list<atemp>::iterator iter; iter = find(tempvals.begin(), tempvals.end(), t); if (iter != tempvals.end()) { tmpidx.delindex(iter->_id); tempvals.erase(iter); } } bool atemp::operator==(const atemp& t) const { return (_id == t._id); } unsigned int atemp::id() const { return _id; } double atemp::val() const { return _val; } double atemp::dot() const { return _dot; } unsigned int active::id() const { return _id; } double active::val() const { return _val; } double active::dot() const { return _dot; } // The PLUS OP const atemp& 
operator+(const atemp& a, const atemp& b) { atemp& r = atemp::moretemps(); r._val = a._val + b._val; r._dot = a._dot + b._dot; opcodes << "ADD\t" << "t[" << a._id << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(a); atemp::deltemp(b); return r; } const atemp& operator+(const active& a, const active& b) { atemp& r = atemp::moretemps(); r._val = a.val() + b.val(); r._dot = a.dot() + b.dot(); opcodes << "ADD\t" << "v[" << a.id() << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n"; return r; } const atemp& operator+(const atemp& a, const active& b) { atemp& r = atemp::moretemps(); r._val = a._val + b.val(); r._dot = a._dot + b.dot(); opcodes << "ADD\t" << "t[" << a._id << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(a); return r; } const atemp& operator+(const active& a, const atemp& b) { atemp& r = atemp::moretemps(); r._val = a.val() + b._val; r._dot = a.dot() + b._dot; opcodes << "ADD\t" << "v[" << a.id() << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(b); return r; } // The MINUS OP const atemp& operator-(const active& a, const atemp& b) { atemp& r = atemp::moretemps(); r._val = a.val() - b._val; r._dot = a.dot() - b._dot; opcodes << "SUB\t" << "v[" << a.id() << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(b); return r; } const atemp& operator-(const atemp& a, const atemp& b) { atemp& r = atemp::moretemps(); r._val = a._val - b._val; r._dot = a._dot - b._dot; opcodes << "SUB\t" << "t[" << a._id << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(a); atemp::deltemp(b); return r; } const atemp& operator-(const atemp& a, const active& b) { atemp& r = atemp::moretemps(); r._val = a._val - b.val(); r._dot = a._dot - b.dot(); opcodes << "SUB\t" << "t[" << a._id << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(a); return r; } const atemp& operator-(const active& a, const active& b) { atemp& r = atemp::moretemps(); r._val = 
a.val() - b.val(); r._dot = a.dot() - b.dot(); opcodes << "SUB\t" << "v[" << a.id() << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n"; return r; } // The MUL OP const atemp& operator*(const active& a, const atemp& b) { atemp& r = atemp::moretemps(); r._val = a.val() * b._val; r._dot = a.dot()*b._val + b._dot*a.val(); opcodes << "MUL\t" << "v[" << a.id() << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(b); return r; } const atemp& operator*(const atemp& a, const atemp& b) { atemp& r = atemp::moretemps(); r._val = a._val * b._val; r._dot = a._dot*b._val + b._dot*a._val; opcodes << "MUL\t" << "t[" << a._id << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(a); atemp::deltemp(b); return r; } const atemp& operator*(const atemp& a, const active& b) { atemp& r = atemp::moretemps(); r._val = a._val * b.val(); r._dot = a._dot*b.val() + b.dot()*a._val; opcodes << "MUL\t" << "t[" << a._id << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(a); return r; } const atemp& operator*(const active& a, const active& b) { atemp& r = atemp::moretemps(); r._val = a.val() * b.val(); r._dot = a.dot()*b.val() + b.dot()*a.val(); opcodes << "MUL\t" << "v[" << a.id() << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n"; return r; } const atemp& operator*(double a, const atemp& b) { atemp& r = atemp::moretemps(); r._val = a * b._val; r._dot = a * b._dot; opcodes << "MUL\t" << "$ " << a << "\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(b); return r; } const atemp& operator*(double a, const active& b) { atemp& r = atemp::moretemps(); r._val = a * b.val(); r._dot = a * b.dot(); opcodes << "MUL\t" << "$ " << a << "\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n"; return r; } const atemp& operator*(const atemp& a, double b) { atemp& r = atemp::moretemps(); r._val = a._val * b; r._dot = a._dot * b; opcodes << "MUL\t" << "t[" << a._id << "]\t\t\t$ " << b << "\t\tt[" << r._id << "]\n"; 
atemp::deltemp(a); return r; } const atemp& operator*(const active& a, double b) { atemp& r = atemp::moretemps(); r._val = a.val() * b; r._dot = a.dot() * b; opcodes << "MUL\t" << "v[" << a.id() << "]\t\t\t$ " << b << "\t\tt[" << r._id << "]\n"; return r; } // The DIV OP const atemp& operator/(const atemp& a, double b) { atemp& r = atemp::moretemps(); double i_b = 1.0/b; r._val = a._val * i_b; r._dot = a._dot * i_b; opcodes << "DIV\t" << "t[" << a._id << "]\t\t\t$ " << b << "\t\tt[" << r._id << "]\n"; atemp::deltemp(a); return r; } const atemp& operator/(const active& a, double b) { atemp& r = atemp::moretemps(); double i_b = 1.0/b; r._val = a.val() * i_b; r._dot = a.dot() * i_b; opcodes << "DIV\t" << "v[" << a.id() << "]\t\t\t$ " << b << "\t\tt[" << r._id << "]\n"; return r; } const atemp& operator/(double a, const atemp& b) { double i_b = 1.0/b._val; double i_b2 = i_b*i_b; atemp& r = atemp::moretemps(); r._val = a * i_b; r._dot = - a * b._dot * i_b2; opcodes << "DIV\t" << "$ " << a << "\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(b); return r; } const atemp& operator/(double a, const active& b) { double i_b = 1.0/b.val(); double i_b2 = i_b*i_b; atemp& r = atemp::moretemps(); r._val = a * i_b; r._dot = - a * b.dot() * i_b2; opcodes << "DIV\t" << "$ " << a << "\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n"; return r; } const atemp& operator/(const atemp& a, const atemp& b) { return (a * (1.0/b)); } const atemp& operator/(const atemp& a, const active& b) { return (a * (1.0/b)); } const atemp& operator/(const active& a, const atemp& b) { return (a * (1.0/b)); } const atemp& operator/(const active& a, const active& b) { return (a * (1.0/b)); } // unary operators // square const atemp& square(const atemp& a) { return (a*a); } const atemp& square(const active& a) { return (a*a); } // SQRT OP const atemp& sqrt(const atemp& a) { atemp& r = atemp::moretemps(); r._val = std::sqrt(a._val); r._dot = - 0.5 * a._dot / r._val ; opcodes << "SQRT\t" << 
"t[" << a._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(a); return r; } const atemp& sqrt(const active& a) { atemp& r = atemp::moretemps(); r._val = std::sqrt(a.val()); r._dot = - 0.5 * a.dot() / r._val ; opcodes << "SQRT\t" << "v[" << a.id() << "]\t\t\tt[" << r._id << "]\n"; return r; } const atemp& sin(const active& a) { atemp& r = atemp::moretemps(); r._val = std::sin(a.val()); r._dot = std::cos(a.val())*a.dot(); opcodes << "SIN\t" << "v[" << a.id() << "]\t\t\tt[" << r._id << "]\n"; return r; } const atemp& sin(const atemp& a) { atemp& r = atemp::moretemps(); r._val = std::sin(a._val); r._dot = std::cos(a._val)*a._dot; opcodes << "SIN\t" << "t[" << a._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(a); return r; } const atemp& cos(const active& a) { atemp& r = atemp::moretemps(); r._val = std::cos(a.val()); r._dot = - std::sin(a.val())*a.dot(); opcodes << "COS\t" << "v[" << a.id() << "]\t\t\tt[" << r._id << "]\n"; return r; } const atemp& cos(const atemp& a) { atemp& r = atemp::moretemps(); r._val = std::cos(a._val); r._dot = - std::sin(a._val)*a._dot; opcodes << "COS\t" << "t[" << a._id << "]\t\t\tt[" << r._id << "]\n"; atemp::deltemp(a); return r; } active::active() { _val = 0; _dot = 0; _id = 0; } active::active(const active& a) { _id = actidx.nextindex(); _val = a._val; _dot = a._dot; opcodes << "SET\t" << "v[" << a._id << "]\t\t\tv[" << _id << "]\n"; } active::active(double a) { _id = actidx.nextindex(); _val = a; _dot = 0; opcodes << "SET\t" << "$ " << a << "\t\tv[" << _id << "]\n"; } active::active(const atemp& a) { _id = actidx.nextindex(); _val = a.val(); _dot = a.dot(); opcodes << "SET\t" << "t[" << a.id() << "]\t\t\tv[" << _id << "]\n"; atemp::deltemp(a); } active::~active() { actidx.delindex(_id); } const active& active::operator=(const atemp& a) { if (_id == 0) _id = actidx.nextindex(); _val = a.val(); _dot = a.dot(); opcodes << "SET\t" << "t[" << a.id() << "]\t\t\tv[" << _id << "]\n"; atemp::deltemp(a); return *this; } const active& 
active::operator=(const active& a) { if (_id == 0) _id = actidx.nextindex(); _val = a._val; _dot = a._dot; opcodes << "SET\t" << "v[" << a._id << "]\t\t\tv[" << _id << "]\n"; return a; } const active& active::operator=(double a) { if (_id == 0) _id = actidx.nextindex(); _val = a; _dot = 0; opcodes << "SET\t" << "$ " << a << "\t\tv[" << _id << "]\n"; return *this; } void active::set_dot(double d) { _dot = d; } } // namespace using AD::active; using AD::opcodes; active myfunc2(const active& x1, const active& x2, const active& x3, const active& x4) { active r, u, c, s, cs; u = 0.5 * x3 * x4; s = sin(u); c = cos(u); cs = c * s; r = ( square(x2 + x1) * square(u) * (u - cs)/ square(u*c - s) + square(x2 - x1)* (u + cs)/square(s)) * x3 / 8.0; return r; } int main() { opcodes.open("opcodess.1"); opcodes.setf(ios::scientific); opcodes.precision(6); active y, x1, x2, x3, x4; x1 = 0; x2 = 1.0/sqrt(3); x3 = M_PI; x4 = 1.5; x4.set_dot(1.0); y = myfunc2(x1, x2, x3, x4); cout << "y (val) = " << y.val() << "\ty (dot) = " << y.dot() << "\n"; opcodes.close(); return 0; }