Exercise 2.4: Variable Reuse

Here is a program that implements variable reuse in a way that traverses the computational graph only once. The assumption is that the compiler cannot recognise common subexpressions; any common subexpression that the user wishes to reuse must therefore be declared as a named variable, otherwise it will be recomputed.

dyntape.hpp


C++ code colored by C++2HTML
#include <list>

using std::list;

namespace AD {

// Allocator for small unsigned identifiers that hands released identifiers
// out again before creating new ones.
class indices {
public:
    indices();
    ~indices();

    // Returns an identifier and records it as live.
    unsigned int nextindex();

    // Releases idx so that a later nextindex() call may return it again.
    void delindex(unsigned int idx);

protected:
    // Copying is not supported; the definitions of these two throw.
    indices(const indices&);
    const indices& operator=(const indices&);

    list<unsigned int> liveidx;   // identifiers currently handed out
    list<unsigned int> reuseidx;  // released identifiers waiting to be reused
    unsigned int maxlive;         // largest identifier created so far
};

class active;  // defined further down; named here for the mixed-operand overloads

// Intermediate result of an arithmetic expression. It carries a value, a
// derivative and an identifier. The slots live in the static list tempvals;
// the operators obtain them through moretemps() and release consumed operands
// through deltemp().
class atemp {
public:
    atemp(const atemp&);
    ~atemp();
    const atemp& operator=(const atemp&);

    // Two temporaries compare equal when their identifiers match.
    bool operator==(const atemp&) const;

    // Removes the stored temporary matching the argument and frees its identifier.
    static void deltemp(const atemp&);

    unsigned int id() const;
    double val() const;
    double dot() const;

    // addition
    friend const atemp& operator+(const atemp&, const atemp&);
    friend const atemp& operator+(const atemp&, const active&);
    friend const atemp& operator+(const active&, const atemp&);
    friend const atemp& operator+(const active&, const active&);

    // subtraction
    friend const atemp& operator-(const atemp&, const atemp&);
    friend const atemp& operator-(const atemp&, const active&);
    friend const atemp& operator-(const active&, const atemp&);
    friend const atemp& operator-(const active&, const active&);

    // multiplication
    friend const atemp& operator*(const atemp&, const atemp&);
    friend const atemp& operator*(const atemp&, const active&);
    friend const atemp& operator*(const active&, const atemp&);
    friend const atemp& operator*(const active&, const active&);
    friend const atemp& operator*(const atemp&, double);
    friend const atemp& operator*(const active&, double);
    friend const atemp& operator*(double, const atemp&);
    friend const atemp& operator*(double, const active&);

    // division
    friend const atemp& operator/(const atemp&, const atemp&);
    friend const atemp& operator/(const atemp&, const active&);
    friend const atemp& operator/(const active&, const atemp&);
    friend const atemp& operator/(const active&, const active&);
    friend const atemp& operator/(const atemp&, double);
    friend const atemp& operator/(const active&, double);
    friend const atemp& operator/(double, const atemp&);
    friend const atemp& operator/(double, const active&);

    // elementary functions
    friend const atemp& square(const atemp&);
    friend const atemp& square(const active&);
    friend const atemp& sqrt(const atemp&);
    friend const atemp& sqrt(const active&);
    friend const atemp& sin(const atemp&);
    friend const atemp& sin(const active&);
    friend const atemp& cos(const atemp&);
    friend const atemp& cos(const active&);

protected:
    // Only the class itself and its friends create blank temporaries.
    atemp();

    // Appends a new slot to tempvals and returns a reference to it.
    static atemp& moretemps();

    double _val;       // function value
    double _dot;       // derivative value
    unsigned int _id;  // identifier printed as t[_id] on the tape

    static list<atemp> tempvals;  // all temporaries currently alive
    static indices tmpidx;        // identifier allocator for temporaries
};

// Named differentiable variable: a value, a derivative and an identifier
// drawn from the shared allocator actidx. An identifier of 0 marks a variable
// that has been default-constructed but not yet assigned.
class active {
public:
    active();
    active(double);
    active(const active&);
    active(const atemp&);
    ~active();

    const active& operator=(double);
    const active& operator=(const active&);
    const active& operator=(const atemp&);

    unsigned int id() const;
    double val() const;
    double dot() const;

    // Overwrites the derivative, e.g. to seed an input variable.
    void set_dot(double);

protected:
    double _val;       // function value
    double _dot;       // derivative value
    unsigned int _id;  // identifier printed as v[_id] on the tape

    static indices actidx;  // identifier allocator for active variables
};

}

dyntape.cpp


C++ code colored by C++2HTML
#include <iostream>
#include <fstream>
#include <cmath>
#include "dyntape.hpp"
#include <algorithm>

using std::ios;
using std::cout;
using std::ofstream;
using std::find;

namespace AD {

// Tape output stream: every recorded operation writes one text line to it.
// It is opened and closed by main().
ofstream opcodes;

// Starts with no identifiers issued: both lists empty, counter at zero.
indices::indices() : liveidx(), reuseidx(), maxlive(0) {
}

// Resets the counter and empties both bookkeeping lists.
indices::~indices() {
    maxlive = 0;
    reuseidx.clear();
    liveidx.clear();
}

// Copying an allocator is not supported. Both operations are declared
// protected in the class and throw the integer -1 if they are ever reached.
indices::indices(const indices&) {
    throw -1;
}

const indices& indices::operator=(const indices&) {
    throw -1;
}

// Issues an identifier: the front of the reuse list when one is waiting
// there, otherwise a brand-new one obtained by incrementing maxlive. The
// identifier is recorded as live before it is returned.
unsigned int indices::nextindex() {
    unsigned int fresh;
    if (reuseidx.empty()) {
        fresh = ++maxlive;
    } else {
        fresh = reuseidx.front();
        reuseidx.pop_front();
    }
    liveidx.push_back(fresh);
    return fresh;
}

// Releases an identifier so that nextindex() can hand it out again.
//
// idx: identifier previously returned by nextindex().
//
// A request for an identifier that is not currently live (never issued, such
// as 0, or already released) is ignored. Recycling it would let nextindex()
// return it later, either as the "unassigned" marker 0 or as a duplicate.
//
// reuseidx is kept in ascending order so that nextindex() always reuses the
// smallest free identifier. It is already sorted on entry, so inserting at
// the lower bound gives the same list as appending and re-sorting.
void indices::delindex(unsigned int idx) {
    if (std::find(liveidx.begin(), liveidx.end(), idx) == liveidx.end())
        return;
    liveidx.remove(idx);
    reuseidx.insert(std::lower_bound(reuseidx.begin(), reuseidx.end(), idx), idx);
}
   

// Out-of-class definitions of the static data members declared in
// dyntape.hpp: the list of live temporaries and the two identifier
// allocators (one for temporaries, one for active variables).
list<atemp> atemp::tempvals;
indices atemp::tmpidx;
indices active::actidx;

// here is the meat of the code

// A blank temporary: zero value, zero derivative, identifier 0.
atemp::atemp() : _val(0), _dot(0), _id(0) {
}

// Adds a blank temporary to the end of tempvals, gives it the next free
// identifier from tmpidx and returns a reference to the stored element.
atemp& atemp::moretemps() {
    tempvals.push_back(atemp());
    atemp& slot = tempvals.back();
    slot._id = tmpidx.nextindex();
    return slot;
}

// Member-wise copy, including the identifier: no new identifier is allocated.
atemp::atemp(const atemp& t) : _val(t._val), _dot(t._dot), _id(t._id) {
}

// Copies value, derivative and identifier from t.
// Returns a reference to the right-hand operand t, not to *this.
const atemp& atemp::operator=(const atemp& t) {
    this->_id = t._id;
    this->_dot = t._dot;
    this->_val = t._val;
    return t;
}

// Zeroes all three fields; the identifier itself is released by deltemp(),
// not here.
atemp::~atemp() {
    _dot = 0;
    _val = 0;
    _id = 0;
}

void atemp::deltemp(const atemp& t) {
    list<atemp>::iterator iter;
    iter = find(tempvals.begin(), tempvals.end(), t);
    if (iter != tempvals.end()) {
	tmpidx.delindex(iter->_id);
	tempvals.erase(iter);
    }
}


// Identity comparison: only the identifiers are compared, not the values.
bool atemp::operator==(const atemp& t) const {
    return t._id == this->_id;
}


unsigned int atemp::id() const {
    return _id;
}

double atemp::val() const {
    return _val;
}

double atemp::dot() const {
    return _dot;
}

unsigned int active::id() const {
    return _id;
}

double active::val() const {
    return _val;
}

double active::dot() const {
    return _dot;
}

// The PLUS OP
//
// Every overload obtains a result slot from atemp::moretemps(), stores the
// sum of the operand values and the sum of the operand derivatives, writes
// one "ADD" line to opcodes (temporaries print as t[id], active variables as
// v[id]), releases each temporary operand and returns the result slot.

// temporary + temporary
const atemp& operator+(const atemp& a, const atemp& b) {
    atemp& r = atemp::moretemps();
    r._val = a._val + b._val;
    r._dot = a._dot + b._dot;
    opcodes << "ADD\t" << "t[" << a._id << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(a);
    // When a and b name the same temporary, deltemp(a) has already destroyed
    // it; releasing b as well would read a dangling reference.
    if (&a != &b)
        atemp::deltemp(b);
    return r;
}

// variable + variable: no temporaries to release
const atemp& operator+(const active& a, const active& b) {
    atemp& r = atemp::moretemps();
    r._val = a.val() + b.val();
    r._dot = a.dot() + b.dot();
    opcodes << "ADD\t" << "v[" << a.id() << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n";
    return r;
}

// temporary + variable
const atemp& operator+(const atemp& a, const active& b) {
    atemp& r = atemp::moretemps();
    r._val = a._val + b.val();
    r._dot = a._dot + b.dot();
    opcodes << "ADD\t" << "t[" << a._id << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(a);
    return r;
}

// variable + temporary
const atemp& operator+(const active& a, const atemp& b) {
    atemp& r = atemp::moretemps();
    r._val = a.val() + b._val;
    r._dot = a.dot() + b._dot;
    opcodes << "ADD\t" << "v[" << a.id() << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(b);
    return r;
}

// The MINUS OP
//
// Every overload obtains a result slot from atemp::moretemps(), stores the
// difference of the operand values and of the operand derivatives, writes one
// "SUB" line to opcodes (temporaries print as t[id], active variables as
// v[id]), releases each temporary operand and returns the result slot.

// variable - temporary
const atemp& operator-(const active& a, const atemp& b) {
    atemp& r = atemp::moretemps();
    r._val = a.val() - b._val;
    r._dot = a.dot() - b._dot;
    opcodes << "SUB\t" << "v[" << a.id() << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(b);
    return r;
}

// temporary - temporary
const atemp& operator-(const atemp& a, const atemp& b) {
    atemp& r = atemp::moretemps();
    r._val = a._val - b._val;
    r._dot = a._dot - b._dot;
    opcodes << "SUB\t" << "t[" << a._id << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(a);
    // When a and b name the same temporary, deltemp(a) has already destroyed
    // it; releasing b as well would read a dangling reference.
    if (&a != &b)
        atemp::deltemp(b);
    return r;
}

// temporary - variable
const atemp& operator-(const atemp& a, const active& b) {
    atemp& r = atemp::moretemps();
    r._val = a._val - b.val();
    r._dot = a._dot - b.dot();
    opcodes << "SUB\t" << "t[" << a._id << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(a);
    return r;
}

// variable - variable: no temporaries to release
const atemp& operator-(const active& a, const active& b) {
    atemp& r = atemp::moretemps();
    r._val = a.val() - b.val();
    r._dot = a.dot() - b.dot();
    opcodes << "SUB\t" << "v[" << a.id() << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n";
    return r;
}


// The MUL OP
//
// Every overload obtains a result slot from atemp::moretemps(), stores the
// product of the values and the product-rule derivative a'*b + b'*a (or the
// scaled derivative when one operand is a plain double), writes one "MUL"
// line to opcodes (temporaries print as t[id], active variables as v[id],
// constants as "$ value"), releases each temporary operand and returns the
// result slot.

// variable * temporary
const atemp& operator*(const active& a, const atemp& b) {
    atemp& r = atemp::moretemps();
    r._val = a.val() * b._val;
    r._dot = a.dot()*b._val + b._dot*a.val();
    opcodes << "MUL\t" << "v[" << a.id() << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(b);
    return r;
}

// temporary * temporary
const atemp& operator*(const atemp& a, const atemp& b) {
    atemp& r = atemp::moretemps();
    r._val = a._val * b._val;
    r._dot = a._dot*b._val + b._dot*a._val;
    opcodes << "MUL\t" << "t[" << a._id << "]\t\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(a);
    // When a and b name the same temporary (as in square() of a temporary),
    // deltemp(a) has already destroyed it; releasing b as well would read a
    // dangling reference.
    if (&a != &b)
        atemp::deltemp(b);
    return r;
}

// temporary * variable
const atemp& operator*(const atemp& a, const active& b) {
    atemp& r = atemp::moretemps();
    r._val = a._val * b.val();
    r._dot = a._dot*b.val() + b.dot()*a._val;
    opcodes << "MUL\t" << "t[" << a._id << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(a);
    return r;
}

// variable * variable: no temporaries to release
const atemp& operator*(const active& a, const active& b) {
    atemp& r = atemp::moretemps();
    r._val = a.val() * b.val();
    r._dot = a.dot()*b.val() + b.dot()*a.val();
    opcodes << "MUL\t" << "v[" << a.id() << "]\t\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n";
    return r;
}

// constant * temporary
const atemp& operator*(double a, const atemp& b) {
    atemp& r = atemp::moretemps();
    r._val = a * b._val;
    r._dot = a * b._dot;
    opcodes << "MUL\t" << "$ " << a << "\t\tt[" << b._id << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(b);
    return r;
}

// constant * variable
const atemp& operator*(double a, const active& b) {
    atemp& r = atemp::moretemps();
    r._val = a * b.val();
    r._dot = a * b.dot();
    opcodes << "MUL\t" << "$ " << a << "\t\tv[" << b.id() << "]\t\t\tt[" << r._id << "]\n";
    return r;
}

// temporary * constant
const atemp& operator*(const atemp& a, double b) {
    atemp& r = atemp::moretemps();
    r._val = a._val * b;
    r._dot = a._dot * b;
    opcodes << "MUL\t" << "t[" << a._id << "]\t\t\t$ " << b << "\t\tt[" << r._id << "]\n";
    atemp::deltemp(a);
    return r;
}

// variable * constant
const atemp& operator*(const active& a, double b) {
    atemp& r = atemp::moretemps();
    r._val = a.val() * b;
    r._dot = a.dot() * b;
    opcodes << "MUL\t" << "v[" << a.id() << "]\t\t\t$ " << b << "\t\tt[" << r._id << "]\n";
    return r;
}

// The DIV OP
//
// Division by a constant multiplies value and derivative by the reciprocal
// of the constant. A constant divided by an operand b yields a/b with
// derivative -a*b'/b^2. Both forms write one "DIV" line to opcodes.
// Operand-by-operand division is composed from the reciprocal 1.0/b followed
// by a multiplication, so it writes a "DIV" line and then a "MUL" line.

// temporary / constant
const atemp& operator/(const atemp& a, double b) {
    const double inv = 1.0/b;
    atemp& q = atemp::moretemps();
    q._dot = a._dot * inv;
    q._val = a._val * inv;
    opcodes << "DIV\t" << "t[" << a._id << "]\t\t\t$ " << b << "\t\tt[" << q._id << "]\n";
    atemp::deltemp(a);
    return q;
}

// variable / constant
const atemp& operator/(const active& a, double b) {
    const double inv = 1.0/b;
    atemp& q = atemp::moretemps();
    q._dot = a.dot() * inv;
    q._val = a.val() * inv;
    opcodes << "DIV\t" << "v[" << a.id() << "]\t\t\t$ " << b << "\t\tt[" << q._id << "]\n";
    return q;
}

// constant / temporary
const atemp& operator/(double a, const atemp& b) {
    const double inv = 1.0/b._val;
    const double inv_sq = inv*inv;
    atemp& q = atemp::moretemps();
    q._dot = - a * b._dot * inv_sq;
    q._val = a * inv;
    opcodes << "DIV\t" << "$ " << a << "\t\tt[" << b._id << "]\t\t\tt[" << q._id << "]\n";
    atemp::deltemp(b);
    return q;
}

// constant / variable
const atemp& operator/(double a, const active& b) {
    const double inv = 1.0/b.val();
    const double inv_sq = inv*inv;
    atemp& q = atemp::moretemps();
    q._dot = - a * b.dot() * inv_sq;
    q._val = a * inv;
    opcodes << "DIV\t" << "$ " << a << "\t\tv[" << b.id() << "]\t\t\tt[" << q._id << "]\n";
    return q;
}

// temporary / temporary
const atemp& operator/(const atemp& a, const atemp& b) {
    const atemp& recip = 1.0/b;
    return a * recip;
}

// temporary / variable
const atemp& operator/(const atemp& a, const active& b) {
    const atemp& recip = 1.0/b;
    return a * recip;
}

// variable / temporary
const atemp& operator/(const active& a, const atemp& b) {
    const atemp& recip = 1.0/b;
    return a * recip;
}

// variable / variable
const atemp& operator/(const active& a, const active& b) {
    const atemp& recip = 1.0/b;
    return a * recip;
}

// unary operators
// square
//
// Squares a temporary. The result is written directly rather than as a*a:
// that call would pass the same temporary as both operands of operator*,
// which releases each operand and would therefore touch the temporary again
// after it has been destroyed. Here the operand is released exactly once.
// Value, derivative (a'*a + a'*a) and the "MUL" tape line are the same as
// those of a*a.
const atemp& square(const atemp& a) {
    atemp& r = atemp::moretemps();
    r._val = a._val * a._val;
    r._dot = a._dot*a._val + a._dot*a._val;
    opcodes << "MUL\t" << "t[" << a._id << "]\t\t\tt[" << a._id << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(a);
    return r;
}

// Squares an active variable by multiplying it with itself; active operands
// are never released, so passing the same variable twice is safe.
const atemp& square(const active& a) {
    return (a*a);
}

// SQRT OP
//
// Square root with derivative d/dx sqrt(x) = 0.5 / sqrt(x), so the result's
// derivative is 0.5 * a' / sqrt(a). The sign is positive. Each overload
// writes one "SQRT" line to opcodes; the temporary overload also releases
// its operand.
const atemp& sqrt(const atemp& a) {
    atemp& r = atemp::moretemps();
    r._val = std::sqrt(a._val);
    r._dot = 0.5 * a._dot / r._val;
    opcodes << "SQRT\t" << "t[" << a._id << "]\t\t\tt[" << r._id << "]\n";
    atemp::deltemp(a);
    return r;
}

const atemp& sqrt(const active& a) {
    atemp& r = atemp::moretemps();
    r._val = std::sqrt(a.val());
    r._dot = 0.5 * a.dot() / r._val;
    opcodes << "SQRT\t" << "v[" << a.id() << "]\t\t\tt[" << r._id << "]\n";
    return r;
}


// SIN and COS
//
// sin(a) has derivative cos(a)*a'; cos(a) has derivative -sin(a)*a'. Each
// overload writes one "SIN" or "COS" line to opcodes; the overloads taking a
// temporary release it afterwards.

// sine of a variable
const atemp& sin(const active& a) {
    const double x = a.val();
    atemp& out = atemp::moretemps();
    out._val = std::sin(x);
    out._dot = std::cos(x)*a.dot();
    opcodes << "SIN\t" << "v[" << a.id() << "]\t\t\tt[" << out._id << "]\n";
    return out;
}

// sine of a temporary
const atemp& sin(const atemp& a) {
    const double x = a._val;
    atemp& out = atemp::moretemps();
    out._val = std::sin(x);
    out._dot = std::cos(x)*a._dot;
    opcodes << "SIN\t" << "t[" << a._id << "]\t\t\tt[" << out._id << "]\n";
    atemp::deltemp(a);
    return out;
}

// cosine of a variable
const atemp& cos(const active& a) {
    const double x = a.val();
    atemp& out = atemp::moretemps();
    out._val = std::cos(x);
    out._dot = - std::sin(x)*a.dot();
    opcodes << "COS\t" << "v[" << a.id() << "]\t\t\tt[" << out._id << "]\n";
    return out;
}

// cosine of a temporary
const atemp& cos(const atemp& a) {
    const double x = a._val;
    atemp& out = atemp::moretemps();
    out._val = std::cos(x);
    out._dot = - std::sin(x)*a._dot;
    opcodes << "COS\t" << "t[" << a._id << "]\t\t\tt[" << out._id << "]\n";
    atemp::deltemp(a);
    return out;
}

// Default construction: zero value and derivative, and identifier 0, which
// the assignment operators treat as "no index allocated yet".
active::active() : _val(0), _dot(0), _id(0) {
}

// Copy construction: takes a new identifier, copies value and derivative and
// records a "SET" from the source variable.
active::active(const active& a)
    : _val(a._val), _dot(a._dot), _id(actidx.nextindex()) {
    opcodes << "SET\t" << "v[" << a._id << "]\t\t\tv[" << _id << "]\n";
}

// Construction from a constant: takes a new identifier, zero derivative, and
// records a "SET" from the constant.
active::active(double a)
    : _val(a), _dot(0), _id(actidx.nextindex()) {
    opcodes << "SET\t" << "$ " << a << "\t\tv[" << _id << "]\n";
}

// Construction from a temporary: takes a new identifier, copies value and
// derivative, records a "SET" from the temporary and then releases it.
active::active(const atemp& a)
    : _val(a.val()), _dot(a.dot()), _id(actidx.nextindex()) {
    opcodes << "SET\t" << "t[" << a.id() << "]\t\t\tv[" << _id << "]\n";
    atemp::deltemp(a);
}

// Returns this variable's identifier to the allocator.
//
// A default-constructed variable that was never assigned still has _id == 0
// and owns no identifier. Releasing 0 would place it in the reuse pool, and
// the next variable to receive it would look unassigned, so it is skipped.
active::~active() {
    if (_id != 0)
        actidx.delindex(_id);
}

// Assignment operators. Each one first allocates an identifier if this
// variable does not have one yet (_id == 0), then stores the new value and
// derivative and writes one "SET" line to opcodes.

// Assignment from a temporary; the temporary is released afterwards.
// Returns *this.
const active& active::operator=(const atemp& a) {
    if (!_id) {
        _id = actidx.nextindex();
    }
    _dot = a.dot();
    _val = a.val();
    opcodes << "SET\t" << "t[" << a.id() << "]\t\t\tv[" << _id << "]\n";
    atemp::deltemp(a);
    return *this;
}

// Assignment from another variable.
// Returns the right-hand operand a, not *this.
const active& active::operator=(const active& a) {
    if (!_id) {
        _id = actidx.nextindex();
    }
    _dot = a._dot;
    _val = a._val;
    opcodes << "SET\t" << "v[" << a._id << "]\t\t\tv[" << _id << "]\n";
    return a;
}

// Assignment from a constant; the derivative is reset to zero.
// Returns *this.
const active& active::operator=(double a) {
    if (!_id) {
        _id = actidx.nextindex();
    }
    _dot = 0;
    _val = a;
    opcodes << "SET\t" << "$ " << a << "\t\tv[" << _id << "]\n";
    return *this;
}

void active::set_dot(double d) {
    _dot = d;
}

} // namespace

using AD::active;
using AD::opcodes;

// Test function evaluated with active arithmetic so that value, derivative
// and tape are produced in one pass.
//
// With u = 0.5*x3*x4, s = sin(u), c = cos(u) and cs = c*s it returns
//   ( (x2+x1)^2 * u^2 * (u - cs) / (u*c - s)^2
//     + (x2-x1)^2 * (u + cs) / s^2 ) * x3 / 8.
//
// u, s, c and cs are stored in named active variables so that they are
// computed once and reused; subexpressions written inline, such as
// square(u), are evaluated again each time they appear.
//
// NOTE(review): C++ leaves the evaluation order of the operands of *, + and /
// unspecified, so the order of the lines written to opcodes for the long
// expression may differ between compilers.
active myfunc2(const active& x1, const active& x2, const active& x3, const active& x4) {
    active r, u, c, s, cs;
    u = 0.5 * x3 * x4;
    s = sin(u);
    c = cos(u);
    cs = c * s;
    r = ( square(x2 + x1) * square(u) * (u - cs)/ square(u*c - s) + square(x2 - x1)* (u + cs)/square(s)) * x3 / 8.0;
    return r;
}

// Driver: records the tape for myfunc2 in the file "opcodess.1" and prints
// the function value together with its derivative with respect to x4.
int main() {
    // Constants on the tape are printed in scientific notation, 6 digits.
    opcodes.open("opcodess.1");
    opcodes.setf(ios::scientific);
    opcodes.precision(6);
    active y, x1, x2, x3, x4;
    // Assigning a double sets the derivative to zero, so every input starts
    // with a zero seed.
    x1 = 0;
    // sqrt(3) takes an int, so this is ordinary double arithmetic and
    // nothing is recorded for it beyond the SET of the result.
    x2 = 1.0/sqrt(3);
    x3 = M_PI;
    x4 = 1.5;
    // Seed x4 only: y.dot() is then the derivative of y with respect to x4.
    x4.set_dot(1.0);
    y = myfunc2(x1, x2, x3, x4);
    cout << "y (val) = " << y.val() << "\ty (dot) = " << y.dot() << "\n";
    opcodes.close();
    return 0;
}