from inspect import isgeneratorfunction
from functools import partial, wraps

# Names exported by ``from <module> import *``.
__all__ = ["coreturn", "cofunction", "costart"]

def coreturn(val):
    """Finish the current cofunction with result ``val``.

    Raises a StopIteration carrying ``val`` both as the ``covalue``
    attribute (the attribute ``costart`` looks for) and as the standard
    ``StopIteration.value``.

    Note: on Python 3.7+ (PEP 479) a StopIteration escaping a generator
    frame is replaced by ``RuntimeError("generator raised StopIteration")``
    whose ``__cause__`` is this exception, so a driver must look there for
    the value.  Inside a generator-based cofunction on Python 3, a plain
    ``return val`` avoids that conversion.

    Raises:
        StopIteration: always.
    """
    e = StopIteration(val)  # populates .value for Python 3 consumers
    e.covalue = val
    raise e

class Cocall(object):
    """Handle through which a cofunction requests a call to another one.

    ``handle(f, *args, **kwargs)`` records a call to cofunction ``f`` and
    returns the handle itself, which the cofunction must then ``yield`` so
    the driver can start ``f``.  At most one call may be pending at a time.
    """

    def __init__(self):
        # Zero-argument callable that creates the pending call's generator,
        # or None when no call is pending.
        self.pending = None

    def __call__(self, f, *args, **kwargs):
        """Record a call to cofunction ``f``; return this handle for yielding.

        Raises:
            TypeError: if a previously recorded call was never yielded.
        """
        self.assert_no_pending()
        # The lambda defers creating f's generator until the handle is
        # yielded; f.cocall receives this handle as its first argument.
        self.pending = lambda: f.cocall(self, *args, **kwargs)
        return self

    def assert_no_pending(self):
        """Raise TypeError if a recorded call was never yielded.

        The stale call is discarded before raising, so once the error has
        been handled the handle can be used for new cocalls instead of
        failing on every later check.
        """
        if self.pending is not None:
            self.pending = None
            raise TypeError("pending cocall was not yielded")

    def start_pending(self):
        """Generate the pending call's generator (at most one item).

        The pending slot is cleared before the call is made, so a failure
        while creating the generator does not leave it set.
        """
        if self.pending is not None:
            pending, self.pending = self.pending, None
            yield pending()

class cofunction(object):
    """Decorator that turns ``f`` into a cofunction.

    The decorated function receives a ``Cocall`` handle as its first
    argument (after ``self`` when used as a method).  It cannot be called
    directly; it is started with ``costart(f, ...)`` or, from inside another
    cofunction, with ``yield cocall(f, ...)``.
    """

    def __init__(self, f):
        if isgeneratorfunction(f):
            self._f = f
        else:
            # Wrap a plain function in a generator so every cofunction is
            # driven the same way.  ``f`` runs on the first resumption and
            # its result becomes the generator's return value
            # (StopIteration.value) instead of being dropped.
            @wraps(f)  # keep f's name for generator names / error messages
            def _f(*args, **kwargs):
                if False:
                    yield  # unreachable; only marks _f as a generator function
                return f(*args, **kwargs)
            self._f = _f

    def cocall(self, *args, **kwargs):
        """Create the generator for one call (callers pass the handle first)."""
        return self._f(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """Reject direct calls; cofunctions only run under a driver."""
        raise TypeError("cofunctions must be called with cocall or costart")

    def __get__(self, obj, objtype=None):
        """Bind to ``obj`` so ``obj.method`` passes ``obj`` as ``self``."""
        # NOTE(review): when accessed on the class (obj is None) this binds
        # None as ``self``; returning ``self`` there may be preferable.
        return BoundCofunction(self._f, obj)

class BoundCofunction(cofunction):
    """A cofunction with an instance pre-bound as its first argument.

    Created by ``cofunction.__get__``.  The base ``__init__`` is not run:
    ``f`` there is already the generator function stored on the unbound
    cofunction, so only the instance needs to be attached.
    """

    def __init__(self, f, obj):
        bound = partial(f, obj)
        self._f = bound


def _cofunction_outcome(exc):
    """Classify the exception that ended a cofunction's generator.

    Returns ``(value, None)`` when the generator finished normally: via
    ``return value``, via ``coreturn(value)``, or via ``coreturn`` turned
    into a RuntimeError by PEP 479 (Python 3.7+).  Returns ``(None, exc)``
    when the generator raised anything else.
    """
    if isinstance(exc, StopIteration):
        return getattr(exc, "covalue", getattr(exc, "value", None)), None
    cause = exc.__cause__ if isinstance(exc, RuntimeError) else None
    if isinstance(cause, StopIteration) and hasattr(cause, "covalue"):
        return cause.covalue, None
    return None, exc


def costart(f, *args, **kwargs):
    """Run cofunction ``f`` to completion; this function is a generator.

    ``f`` is started with a fresh ``Cocall`` handle followed by
    ``args``/``kwargs``.  Nested cofunctions are kept on an explicit stack
    instead of the Python call stack:

    * ``yield cocall(g, ...)`` starts ``g``.  When ``g`` finishes, its result
      becomes the value of that ``yield`` in the caller.  An exception
      escaping ``g`` is raised in the caller at that ``yield``, so it can be
      caught like an ordinary call.
    * Any other yielded value is passed out to whoever drives ``costart``.
      What the driver sends back is delivered to the yielding cofunction.
      Exceptions thrown into ``costart`` are thrown into that cofunction.
      Closing ``costart`` closes every active cofunction, innermost first.

    A cofunction's result is what it passes to ``coreturn`` or what it
    ``return``s.  The result of ``f`` becomes this generator's own return
    value (``StopIteration.value`` for the driver; requires Python 3.3+).

    Raises:
        TypeError: delivered to a cofunction that yields a non-cocall value
            while a cocall is pending, or to the caller of a cofunction
            that finished with a cocall it never yielded.  If that
            cofunction is ``f`` itself, the error is raised from
            ``costart``.
        Any exception escaping ``f`` propagates out of ``costart``.
    """
    cocall = Cocall()
    stack = [f.cocall(cocall, *args, **kwargs)]
    value = None  # sent into the top cofunction on the next step...
    error = None  # ...unless this is set, in which case it is thrown instead
    while True:
        current = stack[-1]
        try:
            if error is None:
                yielded = current.send(value)
            else:
                exc, error = error, None
                yielded = current.throw(exc)
        except BaseException as e:
            # `current` has finished, either by returning or by raising.
            stack.pop()
            value, error = _cofunction_outcome(e)
            if cocall.pending:
                # The stale cocall can never run.  Drop it so the caller's
                # next cocall is not rejected.  Report it only when the
                # cofunction otherwise finished cleanly.
                cocall.pending = None
                if error is None:
                    msg = "last cocall in cofunction '%s' was not yielded"
                    name = getattr(current, "__name__", repr(current))
                    value, error = None, TypeError(msg % name)
            if stack:
                continue  # resume the caller with `value` or `error`
            if error is not None:
                raise error
            return value

        if yielded is cocall:
            # Push the requested cofunction.  A failure to create its
            # generator is raised in the caller, at its `yield`.
            value = None
            try:
                stack.extend(cocall.start_pending())
            except BaseException as e:
                error = e
            continue

        try:
            cocall.assert_no_pending()
        except TypeError as e:
            # Yielded something else while a cocall was pending.
            cocall.pending = None
            error = e
            continue

        try:
            value = yield yielded
        except GeneratorExit:
            for gen in reversed(stack):
                gen.close()
            raise
        except BaseException as e:
            error = e  # forward to the cofunction that yielded